import logging
import os
import re
from typing import List

from langchain.chains import RetrievalQAWithSourcesChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, GPT4AllEmbeddings
from langchain.llms.llamacpp import LlamaCpp
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.retrievers import WebResearchRetriever
from langchain.schema.callbacks.manager import CallbackManager
from langchain.schema.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.utilities.google_search import GoogleSearchAPIWrapper
from langchain.vectorstores.chroma import Chroma
from pydantic import BaseModel, Field

# Vectorstore to cache the pages fetched by the retriever
vectorstore = Chroma(
    embedding_function=OpenAIEmbeddings(), persist_directory="./chroma_db_oai"
)

# LLM
llm = ChatOpenAI(temperature=0)

# Search (GoogleSearchAPIWrapper reads GOOGLE_CSE_ID and GOOGLE_API_KEY from the environment)
os.environ["GOOGLE_CSE_ID"] = "xxx"
os.environ["GOOGLE_API_KEY"] = "xxx"
search = GoogleSearchAPIWrapper()
# Initialize the retriever: it uses the LLM to generate search queries,
# runs them through the search wrapper, and indexes the fetched pages in the vectorstore
web_research_retriever = WebResearchRetriever.from_llm(
    vectorstore=vectorstore,
    llm=llm,
    search=search,
)


# RetrievalQAWithSourcesChain: answer the question and return the source URLs

user_input = "How do LLM Powered Autonomous Agents work?"
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm, retriever=web_research_retriever
)
result = qa_chain({"question": user_input})
print(result)
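# result is a dict with the question, the synthesized answer, and the source
# URLs it was drawn from; the shape (values illustrative) is roughly:
# {"question": "...", "answer": "...", "sources": "https://..."}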


# Set the run-time logging level so the retriever's intermediate steps are visible
logging.basicConfig()
logging.getLogger("langchain.retrievers.web_research").setLevel(logging.INFO)
user_input = "What is Task Decomposition in LLM Powered Autonomous Agents?"
docs = web_research_retriever.get_relevant_documents(user_input)

# Run question answering over the retrieved documents with load_qa_chain
chain = load_qa_chain(llm, chain_type="stuff")
output = chain({"input_documents": docs, "question": user_input}, return_only_outputs=True)
print(output["output_text"])

# Customize the query-generation prompt and its output parser
search_prompt = PromptTemplate(
    input_variables=["question"],
    template="""You are an assistant tasked with improving Google search 
    results. Generate FIVE Google search queries that are similar to
    this question. The output should be a numbered list of questions and each
    should have a question mark at the end: {question}""",
)


class LineList(BaseModel):
    """List of questions."""

    lines: List[str] = Field(description="Questions")


class QuestionListOutputParser(PydanticOutputParser):
    """Output parser for a list of numbered questions."""

    def __init__(self) -> None:
        super().__init__(pydantic_object=LineList)

    def parse(self, text: str) -> LineList:
        # Also capture a final numbered line that lacks a trailing newline
        lines = re.findall(r"\d+\..*?(?:\n|$)", text)
        return LineList(lines=lines)
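

# Quick sanity check of the parser (the sample input below is illustrative):
print(QuestionListOutputParser().parse("1. What is X?\n2. How does Y work?\n"))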


llm_chain = LLMChain(
    llm=llm,
    prompt=search_prompt,
    output_parser=QuestionListOutputParser(),
)

# Build the retriever directly, supplying the custom LLMChain
web_research_retriever_llm_chain = WebResearchRetriever(
    vectorstore=vectorstore,
    llm_chain=llm_chain,
    search=search,
)
docs = web_research_retriever_llm_chain.get_relevant_documents(user_input)
print(docs)
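# Each item in docs is a LangChain Document: the fetched text lives in
# page_content, and metadata["source"] typically holds the originating URL.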

# Run the pipeline with a local model via llama.cpp
def run_locally():
    n_gpu_layers = 1  # Number of layers to offload to the GPU; 1 is enough for Metal.
    n_batch = 512  # Should be between 1 and n_ctx; size it to the RAM of your Apple Silicon chip.
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llama = LlamaCpp(
        model_path="/Users/rlm/Desktop/Code/llama.cpp/llama-2-13b-chat.ggmlv3.q4_0.bin",
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        n_ctx=4096,  # Context window
        max_tokens=1000,  # Max tokens to generate
        f16_kv=True,  # MUST be set to True, otherwise you will run into problems after a couple of calls
        callback_manager=callback_manager,
        verbose=True,
    )

    # Local GPT4All embeddings replace the OpenAI embeddings used above
    vectorstore_llama = Chroma(
        embedding_function=GPT4AllEmbeddings(), persist_directory="./chroma_db_llama"
    )

    # Rebuild the retriever on top of the local LLM and vectorstore
    web_research_retriever_llama = WebResearchRetriever.from_llm(
        vectorstore=vectorstore_llama,
        llm=llama,
        search=search,
    )
    return web_research_retriever_llama.get_relevant_documents(user_input)

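# Optional: invoke the local pipeline when the script is run directly.
# Note this assumes the model_path above points at an existing GGML file.
if __name__ == "__main__":
    print(run_locally())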